# assembly components
## Convolution + Batch Normalization
# Builds a two-stage layer: a bias-free, padded convolution with `outChannels` output maps,
# kernel shape `kernel` and the given `stride` (weights initialized with "heNormal"),
# followed by batch normalization that uses `bnTimeConst` as its normalization time constant.
ConvBNLayer {outChannels, kernel, stride, bnTimeConst} = Sequential (
    ConvolutionalLayer {outChannels, kernel, stride = stride, pad = true, bias = false, init = "heNormal"} :
    BatchNormalizationLayer {normalizationTimeConstant = bnTimeConst, spatialRank = 2, useCntkEngine = false}
)

## Convolution + Batch Normalization + Rectified Linear Unit (ReLU)
# Forwards all four arguments unchanged to ConvBNLayer and applies ReLU to its output.
ConvBNReLULayer {outChannels, kernelSize, stride, bnTimeConst} = Sequential (
    ConvBNLayer {outChannels, kernelSize, stride, bnTimeConst} : ReLU
)

# ResNet components

# The basic ResNet block: two 3x3 convolutions (each followed by batch normalization,
# the first one also by ReLU) whose result is added to the original input of the block.
# A final ReLU is applied to the sum.
ResNetBasic {outChannels, bnTimeConst} = {
    apply (x) = {
        # Residual branch, both convolutions with stride (1:1)
        b = Sequential (
            ConvBNReLULayer {outChannels, (3:3), (1:1), bnTimeConst} :
            ConvBNLayer {outChannels, (3:3), (1:1), bnTimeConst}
        ) (x)

        # Identity shortcut: the block input is added as-is
        p = b + x
        r = ReLU (p)
    }.r
}.apply

# A block that reduces the feature map resolution. The residual branch consists of two 3x3
# convolutions, the first of which uses `stride`. The shortcut passes the original input
# through a 1x1 convolution (plus batch normalization) with the same `stride`; the two
# results are added and a final ReLU is applied.
ResNetBasicInc {outChannels, stride, bnTimeConst} = {
    apply (x) = {
        # Shortcut: strided 1x1 projection of the block input
        s = ConvBNLayer {outChannels, (1:1), stride, bnTimeConst} (x)

        # Residual branch: strided 3x3 conv, then a stride-1 3x3 conv
        b = Sequential (
            ConvBNReLULayer {outChannels, (3:3), stride, bnTimeConst} :
            ConvBNLayer {outChannels, (3:3), (1:1), bnTimeConst}
        ) (x)

        p = b + s
        r = ReLU (p)
    }.r
}.apply

# A bottleneck ResNet block. To reduce the amount of computation, the two 3x3 convolutions
# of the basic block are replaced by a 1x1 convolution producing `interOutChannels` feature
# maps (usually interOutChannels < outChannels, hence the name "bottleneck"), followed by a
# 3x3 convolution with `interOutChannels` feature maps, and then another 1x1 convolution
# producing `outChannels` feature maps. The block input is added to that result and a
# final ReLU is applied.
ResNetBottleneck {outChannels, interOutChannels, bnTimeConst} = {
    apply (x) = {
        # Residual branch: 1x1 -> 3x3 -> 1x1, all with stride (1:1)
        b = Sequential (
            ConvBNReLULayer {interOutChannels, (1:1), (1:1), bnTimeConst} :
            ConvBNReLULayer {interOutChannels, (3:3), (1:1), bnTimeConst} :
            ConvBNLayer {outChannels, (1:1), (1:1), bnTimeConst}
        ) (x)

        # Identity shortcut: the block input is added as-is
        p = b + x
        r = ReLU (p)
    }.r
}.apply

# A bottleneck block that reduces the feature map resolution. The reduction can be done
# either at the first 1x1 convolution by specifying "stride1x1=(2:2)" (original paper),
# or at the 3x3 convolution by specifying "stride3x3=(2:2)" (Facebook re-implementation).
# The shortcut is a 1x1 convolution (plus batch normalization) whose stride is the
# element-wise product of `stride1x1` and `stride3x3`.
ResNetBottleneckInc {outChannels, interOutChannels, stride1x1, stride3x3, bnTimeConst} = {
    apply (x) = {
        # Residual branch: 1x1 (stride1x1) -> 3x3 (stride3x3) -> 1x1 (stride 1)
        b = Sequential (
            ConvBNReLULayer {interOutChannels, (1:1), stride1x1, bnTimeConst} :
            ConvBNReLULayer {interOutChannels, (3:3), stride3x3, bnTimeConst} :
            ConvBNLayer {outChannels, (1:1), (1:1), bnTimeConst}
        ) (x)

        # Shortcut: combined per-dimension stride of the two strided convolutions above;
        # the array has one entry per element of stride1x1
        stride = array[0..Length(stride1x1)-1] (i => stride1x1[i] * stride3x3[i])
        s = ConvBNLayer {outChannels, (1:1), stride, bnTimeConst} (x)

        p = b + s
        r = ReLU (p)
    }.r
}.apply

# Chains n layers with Sequential; layer number k (k = 0..n-1) is created by calling c (k).
NLayerStack {n, c} = Sequential (array[0..n-1] (layerIndex => c (layerIndex)))
# A stack of n ResNetBasic blocks, each created with the same outChannels and bnTimeConst
# (the layer index passed in by NLayerStack is not used).
ResNetBasicStack {n, outChannels, bnTimeConst} = NLayerStack {n, layerIndex => ResNetBasic {outChannels, bnTimeConst}}
# A stack of n ResNetBottleneck blocks, each created with the same outChannels,
# interOutChannels and bnTimeConst (the layer index passed in by NLayerStack is not used).
ResNetBottleneckStack {n, outChannels, interOutChannels, bnTimeConst} = NLayerStack {n, layerIndex => ResNetBottleneck {outChannels, interOutChannels, bnTimeConst}}